{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3bd7bf44-5271-4c15-b0b0-9ffb328852a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "cannot found `mindformers.experimental`, please install dev version by\n",
      "`pip install git+https://gitee.com/mindspore/mindformers` \n",
      "or remove mindformers by \n",
      "`pip uninstall mindformers`\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 1.294 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import os\n",
    "from tqdm.notebook import tqdm\n",
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "from mindnlp.transformers import (\n",
    "    BertGenerationTokenizer,\n",
    "    BertGenerationDecoder,\n",
    "    BertGenerationConfig,\n",
    "    CLIPModel,\n",
    "    CLIPTokenizer\n",
    ")\n",
    "from loaders.ZO_Clip_loaders import cifar10_single_isolated_class_loader\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from mindspore import context\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9edcb918-e4c4-4a9d-86d8-4bd22f034688",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_for_clip(batch_sentences, tokenizer):\n",
    "    # 使用CLIPTokenizer直接处理\n",
    "    inputs = tokenizer(\n",
    "        batch_sentences,\n",
    "        padding=True,\n",
    "        truncation=True,\n",
    "        max_length=77,\n",
    "        return_tensors=\"ms\"\n",
    "    )\n",
    "    return inputs.input_ids\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7f5c81bd-7c4a-4ec6-9592-45b993b01443",
   "metadata": {},
   "outputs": [],
   "source": [
    "def greedysearch_generation_topk(clip_embed, bert_model, batch_size=32):\n",
    "    # 处理多个样本\n",
    "    N = clip_embed.shape[0]\n",
    "    max_len = 77\n",
    "\n",
    "    # 初始化batch的target序列\n",
    "    target_lists = [[berttokenizer.bos_token_id] for _ in range(N)]\n",
    "    top_k_lists = [[] for _ in range(N)]\n",
    "    bert_model.set_train(False)\n",
    "\n",
    "    for i in range(max_len):\n",
    "        # 批量处理target序列\n",
    "        targets = ms.Tensor(target_lists,dtype=ms.int64)\n",
    "\n",
    "        # 批量生成position_ids\n",
    "        position_ids = ms.Tensor(np.arange(targets.shape[1])[None].repeat(N, axis=0), ms.int32)\n",
    "\n",
    "        # 批量生成attention_mask\n",
    "        attention_mask = ops.ones((N,targets.shape[1]), dtype=ms.int32)\n",
    "\n",
    "\n",
    "        # 批量推理\n",
    "        out = bert_model(\n",
    "            input_ids=targets,\n",
    "            attention_mask=attention_mask,\n",
    "            position_ids=position_ids,\n",
    "            encoder_hidden_states=clip_embed,\n",
    "        )\n",
    "\n",
    "        # 批量处理预测结果\n",
    "        pred_idxs = out.logits.argmax(axis=2)[:, -1].astype(ms.int64)\n",
    "        _, top_k = ops.topk(out.logits, dim=2, k=35)\n",
    "\n",
    "        # 更新每个样本的序列\n",
    "        for j in range(N):\n",
    "            target_lists[j].append(pred_idxs[j].item())\n",
    "            top_k_lists[j].append(top_k[j, -1])\n",
    "\n",
    "        # 检查是否所有样本都已完成\n",
    "        if all(len(t) >= 10 for t in target_lists):\n",
    "            break\n",
    "\n",
    "    # 处理返回结果\n",
    "    results = []\n",
    "    for i in range(N):\n",
    "        top_k_tensor = ops.concat(top_k_lists[i])\n",
    "        target_tensor = ms.Tensor(target_lists[i], dtype=ms.int64)\n",
    "        results.append((target_tensor, top_k_tensor))\n",
    "\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "066b8086-1a9f-40fc-bc66-acde98a8124c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_decoder(clip_model, berttokenizer, image_loaders=None, bert_model=None):\n",
    "    print('t2')\n",
    "    ablation_splits = [\n",
    "        ['airplane', 'automobile', 'truck', 'horse', 'cat', 'bird', 'ship', 'dog', 'deer', 'frog'],\n",
    "        ['airplane', 'automobile', 'truck', 'bird', 'ship', 'frog', 'deer', 'dog', 'horse', 'cat']\n",
    "    ]\n",
    "\n",
    "    auc_list_sum = []\n",
    "    for split in ablation_splits:\n",
    "        print(\"处理split:\", split)\n",
    "        seen_labels = split[:6]\n",
    "        seen_descriptions = [f\"This is a photo of a {label}\" for label in seen_labels]\n",
    "        targets = ms.Tensor(6000 * [0] + 4000 * [1], dtype=ms.int32)\n",
    "        max_num_entities = 0\n",
    "        ood_probs_sum = []\n",
    "\n",
    "        # 批量处理每个语义标签\n",
    "        for semantic_label in split:\n",
    "            print(f\"处理类别: {semantic_label}\")\n",
    "            loader = image_loaders[semantic_label]\n",
    "\n",
    "            for batch_data in tqdm(loader.create_dict_iterator()):\n",
    "                # 获取整个批次的图像\n",
    "                batch_images = batch_data[\"image\"]\n",
    "\n",
    "                # 批量获取CLIP图像特征\n",
    "                clip_model.set_train(False)\n",
    "                clip_out = clip_model.get_image_features(pixel_values=batch_images)\n",
    "                clip_extended_embed = ops.repeat_elements(clip_out, rep=2, axis=1)\n",
    "                clip_extended_embed = ops.expand_dims(clip_extended_embed, 1)\n",
    "\n",
    "                # 批量生成文本\n",
    "                batch_results = greedysearch_generation_topk(clip_extended_embed, bert_model)\n",
    "                del clip_extended_embed\n",
    "                del clip_out\n",
    "\n",
    "                # 批量处理生成的文本\n",
    "                batch_target_tokens = []\n",
    "                batch_topk_tokens = []\n",
    "                \n",
    "                # 批量解码token\n",
    "                for target_list, topk_list in batch_results:\n",
    "                    target_tokens = [berttokenizer.decode(int(pred_idx.asnumpy())) for pred_idx in target_list]\n",
    "                    topk_tokens = [berttokenizer.decode(int(pred_idx.asnumpy())) for pred_idx in topk_list]\n",
    "                    batch_target_tokens.append(target_tokens)\n",
    "                    batch_topk_tokens.append(topk_tokens)\n",
    "\n",
    "                # 批量找出唯一实体\n",
    "                batch_unique_entities = []\n",
    "                for topk_tokens in batch_topk_tokens:\n",
    "                    unique_entities = list(set(topk_tokens) - set(seen_labels))\n",
    "                    batch_unique_entities.append(unique_entities)\n",
    "                    max_num_entities = max(max_num_entities, len(unique_entities))\n",
    "\n",
    "                # 批量构建描述\n",
    "                batch_all_desc = []\n",
    "                for unique_entities in batch_unique_entities:\n",
    "                    all_desc = seen_descriptions + [f\"This is a photo of a {label}\" for label in unique_entities]\n",
    "                    batch_all_desc.append(all_desc)\n",
    "\n",
    "                # 批量编码文本\n",
    "                batch_all_desc_ids = [tokenize_for_clip(all_desc, cliptokenizer) for all_desc in batch_all_desc]\n",
    "                \n",
    "                # 批量计算图像特征\n",
    "                image_features = clip_model.get_image_features(pixel_values=batch_images)\n",
    "                image_features = image_features / ops.norm(image_features, dim=-1, keepdim=True)\n",
    "\n",
    "                for b_idx in range(len(batch_results)):\n",
    "                    text_features = clip_model.get_text_features(input_ids=batch_all_desc_ids[b_idx])\n",
    "                    text_features = text_features / ops.norm(text_features, dim=-1, keepdim=True)\n",
    "                    \n",
    "                    similarity = 100.0 * (image_features[b_idx:b_idx+1] @ text_features.T)\n",
    "                    zeroshot_probs = ops.softmax(similarity, axis=-1).squeeze()\n",
    "                    \n",
    "                    # 累积OOD概率\n",
    "                    ood_prob_sum = float(ops.sum(zeroshot_probs[6:]).asnumpy())\n",
    "                    ood_probs_sum.append(ood_prob_sum)\n",
    "                del batch_target_tokens\n",
    "                del batch_topk_tokens\n",
    "                del batch_unique_entities\n",
    "                del batch_all_desc\n",
    "                del image_features\n",
    "\n",
    "        # 计算当前split的AUC分数\n",
    "        auc_sum = roc_auc_score(targets.asnumpy(), np.array(ood_probs_sum))\n",
    "        print('当前split的sum_ood AUROC={}'.format(auc_sum))\n",
    "        auc_list_sum.append(auc_sum)\n",
    "\n",
    "    print('所有AUC分数:', auc_list_sum)\n",
    "    print('AUC均值和标准差:', np.mean(auc_list_sum), np.std(auc_list_sum))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "eb15febd-3cb5-4f25-9015-fc6c0bbe750e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_args_in_notebook():\n",
    "    args = argparse.Namespace(\n",
    "        trained_path='./trained_models/COCO/'\n",
    "    )\n",
    "    return args\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d6b3ad72-3685-4a2a-a655-198d785e4432",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "BertGenerationDecoder has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
      "  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
      "  - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t0\n",
      "t1\n",
      "t2\n",
      "处理split: ['airplane', 'automobile', 'truck', 'horse', 'cat', 'bird', 'ship', 'dog', 'deer', 'frog']\n",
      "处理类别: airplane\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "64747b8e9f9740e1b772c721e3b9e99f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: automobile\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ba3926297d994ba79ba5518c4435969f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: truck\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1eb2fd7c0a64d328764e9a03f8cb7f2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: horse\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b5db460fb60d4de5ab4e12107f33f062",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: cat\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "035cb664d4f4413c957ae58b88d4d46b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: bird\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84d71faf898b409586b7b3ea95a8a1a7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: ship\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8f2b40df12334cd59de24f32279a0449",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: dog\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "326d0ea3aa1d4f688cb7cf280e4eb57e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: deer\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "52306cbd59364f2aa888762639b3e071",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: frog\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3370ea050849489cbc0d3ce4ef06315f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前split的sum_ood AUROC=0.925871875\n",
      "处理split: ['airplane', 'automobile', 'truck', 'bird', 'ship', 'frog', 'deer', 'dog', 'horse', 'cat']\n",
      "处理类别: airplane\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e00900edfd9a44f59805dcdaad2036f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: automobile\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df418c72de5e44cab5137272f4aba838",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: truck\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0309ee0d2595416ca33f73e7fe2a9d24",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: bird\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "964c19c4af7d42779e3c0fad35fbbc8b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: ship\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3cbb961cc0484e8a9d1b0ca271db744c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: frog\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "122e0bd2c8db4184ad93f398879bdb81",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: deer\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "afd942af7e1c4d4c9a79801849f697ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: dog\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "20997f66fcf04efa9b726d738ec2c3b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: horse\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cac05b8b1ed445278729a3608d43c309",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "处理类别: cat\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5cd07d19fa954f53828b6d9fcfd64831",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前split的sum_ood AUROC=0.9849705\n",
      "所有AUC分数: [0.925871875, 0.9849705]\n",
      "AUC均值和标准差: 0.9554211875 0.02954931249999998\n"
     ]
    }
   ],
   "source": [
    "if __name__ == '__main__':\n",
    "    # 判断是否在notebook环境\n",
    "    if 'ipykernel' in sys.modules or 'IPython' in sys.modules:\n",
    "        # notebook环境\n",
    "        args = get_args_in_notebook()\n",
    "        context.set_context(device_target=\"Ascend\")\n",
    "    else:\n",
    "        # 命令行环境\n",
    "        parser = argparse.ArgumentParser()\n",
    "        parser.add_argument('--trained_path', type=str, default='./trained_models/COCO/')\n",
    "        args = parser.parse_args()\n",
    "        context.set_context(device_target=\"Ascend\")\n",
    "\n",
    "    args.saved_model_path = args.trained_path + '/ViT-B32/'\n",
    "\n",
    "    if not os.path.exists(args.saved_model_path):\n",
    "        os.makedirs(args.saved_model_path)\n",
    "\n",
    "    # 初始化tokenizers\n",
    "    berttokenizer = BertGenerationTokenizer.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder')\n",
    "\n",
    "    # 加载CLIP模型和tokenizer\n",
    "    model_name = 'openai/clip-vit-base-patch32'\n",
    "    try:\n",
    "        clip_model = CLIPModel.from_pretrained(model_name)\n",
    "        cliptokenizer = CLIPTokenizer.from_pretrained(model_name)\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading model from mirror, trying direct download: {e}\")\n",
    "        clip_model = CLIPModel.from_pretrained(model_name)\n",
    "        cliptokenizer = CLIPTokenizer.from_pretrained(model_name)\n",
    "\n",
    "    # 初始化模型\n",
    "    if (not os.path.exists(f\"{args.saved_model_path}/decoder_model\")):\n",
    "        bert_config = BertGenerationConfig.from_pretrained(\"google/bert_for_seq_generation_L-24_bbc_encoder\")\n",
    "        bert_config.is_decoder = True\n",
    "        bert_config.add_cross_attention = True\n",
    "        bert_config.return_dict = True\n",
    "        bert_model = BertGenerationDecoder.from_pretrained(\"google/bert_for_seq_generation_L-24_bbc_encoder\",\n",
    "                                                               config=bert_config)\n",
    "    else:\n",
    "        bert_model = BertGenerationDecoder.from_pretrained(f\"{args.saved_model_path}/decoder_model\")\n",
    "\n",
    "    # print(bert_model)\n",
    "    print(\"t0\")\n",
    "    cifar10_loaders = cifar10_single_isolated_class_loader()\n",
    "    print(\"t1\")\n",
    "    image_decoder(clip_model, berttokenizer, image_loaders=cifar10_loaders,bert_model=bert_model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4d05ac7-a091-4431-8c3c-ea0b380b5b47",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
